import os
import torch
import torchvision.transforms as v2
from torchvision import datasets


def get_miniwebvision(
    batch_size: int = 128,
    num_workers: int = 4,
    resize_image: int = 224,
    test_batch_size: int = 100,
    data_root: str = os.path.join('.', 'data'),
):
    """Build train/test DataLoaders over an ImageFolder-style dataset.

    The training images are read from ``<data_root>/train`` and the
    evaluation images from ``<data_root>/val``; each directory must follow
    the ``torchvision.datasets.ImageFolder`` layout (one sub-folder per
    class).

    Args:
        batch_size: Batch size of the training loader.
        num_workers: Number of worker processes used by *both* loaders.
        resize_image: Side length (pixels) of the square images produced by
            the train and test transforms.
        test_batch_size: Batch size of the evaluation loader.
        data_root: Directory containing the ``train`` and ``val`` folders.

    Returns:
        A ``(train_dataloader, test_dataloader, num_classes)`` tuple, where
        ``num_classes`` is the number of class folders found under
        ``<data_root>/train``.
    """
    image_size = (resize_image, resize_image)

    # Shared normalization so train and test inputs are scaled identically
    # (these are the channel statistics torchvision documents for its
    # ImageNet-pretrained models).
    normalize = v2.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )

    # Training pipeline: random crop + flip augmentation.
    train_transform = v2.Compose([
        v2.RandomResizedCrop(image_size),
        v2.RandomHorizontalFlip(),
        v2.ToTensor(),
        normalize,
    ])

    # Evaluation pipeline: deterministic resize only.
    test_transform = v2.Compose([
        v2.Resize(image_size),
        v2.ToTensor(),
        normalize,
    ])

    data_train = datasets.ImageFolder(
        root=os.path.join(data_root, 'train'),
        transform=train_transform,
    )
    data_test = datasets.ImageFolder(
        root=os.path.join(data_root, 'val'),
        transform=test_transform,
    )

    train_dataloader = torch.utils.data.DataLoader(
        data_train,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True,
    )
    # Honor the caller's ``num_workers`` here as well, so that e.g.
    # ``num_workers=0`` really disables worker processes for evaluation too.
    test_dataloader = torch.utils.data.DataLoader(
        data_test,
        batch_size=test_batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )

    # Derive the class count from the folders ImageFolder discovered, so it
    # always matches the label indices the loaders emit.
    num_classes = len(data_train.classes)

    return train_dataloader, test_dataloader, num_classes